import absl.app
import absl.flags

import gym
import numpy as np
import chex

import algos
from algos.model import (
  DecoupledQFunction,
  FullyConnectedNetwork,
  FullyConnectedQFunction,
  FullyConnectedVFunction,
  ResTanhGaussianPolicy,
  SamplerPolicyEncoder,
  ClipGaussianPolicy,
  TanhGaussianPolicy,
  IdentityEncoder,
  ResEncoder,
  ResQFunction,
  ResClipGaussianPolicy,
)
# from data import Dataset, RandSampler, SlidingWindowSampler, RLUPDataset,\
#                  DM2Gym, BalancedSampler, PrefetchBalancedSampler
from data import RandSampler, BalancedSampler, PrefetchBalancedSampler
from utilities.jax_utils import batch_to_jax
from utilities.replay_buffer import get_d4rl_dataset, ReplayBuffer
from utilities.sampler import TrajSampler, StepSampler
from utilities.utils import (
  # SOTALogger,
  WandBLogger,
  Timer,
  define_flags_with_default,
  get_user_flags,
  prefix_metrics,
  set_random_seed,
  normalize
)
from viskit.logging import logger, setup_logger
from pathlib import Path
from flax import linen as nn
import tqdm

# The SOTALogger import above is commented out; WandBLogger is bound to the
# same name so the rest of this file can keep referring to `SOTALogger`.
SOTALogger = WandBLogger

# Command-line flags for this script together with their default values.
# `main` reads them through `absl.flags.FLAGS` and records them via
# `get_user_flags(FLAGS, FLAGS_DEF)`.
FLAGS_DEF = define_flags_with_default(
  env='walker2d-medium-v2',
  algo='ConservativeSAC',
  max_traj_length=1000,
  seed=42,
  save_model=False,
  batch_size=256,
  reward_scale=1.0,
  reward_bias=0.0,
  clip_action=0.999,
  encoder_arch='64-64',
  policy_arch='256-256',
  qf_arch='256-256',
  orthogonal_init=False,
  policy_log_std_multiplier=1.0,
  policy_log_std_offset=-1.0,
  # n_epochs=2000,
  n_epochs=1000,
  bc_epochs=0,
  n_train_step_per_epoch=1000,
  eval_period=5,
  eval_n_trajs=10,
  eval_smooth=True, # also report the average of the last 10 evaluations
  # configs for training scheme
  sampler='random', # online, random, m-balanced-reward, m-balanced-return, c-balanced-reward, c-balanced-return
  reweight=False,
  reweight_eval=False,
  reweight_improve=False,
  reweight_constraint=0, # 0: no reweight; 1: reweight the second term; otherwise (presumably 2): reweight both terms -- TODO confirm
  cql_coff=1.0,
  window_size=3000,  # window size for online sampler
  logging=SOTALogger.get_default_config(),
  top_n=100, # only use the top n% timesteps in dataset ordered by episode return
  sort_by='return', # ordering criterion forwarded to get_d4rl_dataset as `sort_by` (used together with top_n)
  dataset='d4rl',
  rl_unplugged_task_class = 'control_suite',
  use_resnet=False,
  use_layer_norm=False,
  activation='elu',
  replay_buffer_size=1000000,
  n_env_steps_per_epoch=1000,
  online=False,
  normalize_reward=False,
  embedding_dim=64,
  decoupled_q=False,
  # resample: options for sampling probabilities loaded from saved weights (used when bc_eval is set)
  bc_eval=False,
  weight_num=3,
  weight_ensemble='mean',
  weight_path='',
  iter=1,
  weight_func='linear',
  std=2.0,
  eps=0.1,
  eps_max=0.0,
  exp_lambd=1.0,
  cql_min_q_weight=5.0,
)

class Dataset(object):
  """In-memory table of equally indexed arrays with a pluggable index sampler.

  The wrapped dict is stored by reference, and a `'weights'` entry shaped like
  `data['rewards']` and filled with ones is written into it on construction.
  """

  def __init__(self, data: dict) -> None:
    # Snapshot the key order before 'weights' is written into the dict, then
    # list 'weights' last so it is always part of a retrieved batch.
    self._keys = [*data.keys(), 'weights']
    self._data = data
    self._sampler = None
    # Every entry starts with unit weight.
    self._data['weights'] = np.ones_like(data['rewards'])

  def size(self):
    """Return the length of the array stored under the first key."""
    first_key = self._keys[0]
    return len(self._data[first_key])

  def retrieve(self, indices: np.ndarray):
    """Return a dict holding, for every key, the rows selected by `indices`."""
    return {name: self._data[name][indices, ...] for name in self._keys}

  @property
  def sampler(self):
    """The attached sampler; asserts that one has been set."""
    assert self._sampler is not None
    return self._sampler

  def set_sampler(self, sampler):
    """Attach the object whose `sample()` supplies indices for `sample`."""
    self._sampler = sampler

  def set_weight(self, weight):
    """Replace the weights with `weight`, reshaped to the current weights' shape."""
    target_shape = self._data['weights'].shape
    self._data['weights'] = np.reshape(weight, target_shape)

  def sample(self):
    """Draw indices from the attached sampler and return the matching batch."""
    picked = self.sampler.sample()
    return self.retrieve(picked)

def main(argv):
  """Train the algorithm selected by `FLAGS.algo` and evaluate it periodically.

  Steps, in order:
    1. Build the algorithm config, the experiment loggers and seed the RNGs.
    2. Build the evaluation sampler and either an online step sampler with a
       replay buffer, or an offline `Dataset` with an index sampler whose
       probabilities come from rewards/returns or from saved weights
       (`FLAGS.bc_eval`).
    3. Build encoder, policy and value networks (MLP or ResNet variants).
    4. Run `FLAGS.n_epochs` epochs of training, evaluating on epoch 0 and then
       every `FLAGS.eval_period` epochs, and log metrics after every epoch.

  Args:
    argv: Positional command-line arguments handed over by `absl.app.run`;
      not used here.

  Raises:
    NotImplementedError: For an unknown `FLAGS.dataset`, `FLAGS.weight_ensemble`
      or `FLAGS.weight_func`, or for `algo == 'IQL'` with `use_resnet`.
  """
  FLAGS = absl.flags.FLAGS
  off_algo = getattr(algos, FLAGS.algo)
  try:
    algo_cfg = off_algo.get_default_config({
      'env': FLAGS.env,
      'cql_min_q_weight': FLAGS.cql_min_q_weight,
      'cql_coff': FLAGS.cql_coff,
    })
  except Exception:
    # Presumably some algorithm configs reject these override keys; fall back
    # to the plain defaults. `Exception` (rather than a bare except) keeps
    # KeyboardInterrupt / SystemExit propagating.
    algo_cfg = off_algo.get_default_config()

  # NOTE(review): this override runs after `algo_cfg` was built from
  # FLAGS.cql_coff, so it only affects the logged variant below -- confirm
  # whether it was meant to reach the algorithm config as well.
  if FLAGS.bc_eval and FLAGS.env == 'halfcheetah-medium-expert-v2':
    FLAGS.cql_coff = 2.0

  variant = get_user_flags(FLAGS, FLAGS_DEF)
  for k, v in algo_cfg.items():
    variant[f'algo.{k}'] = v

  sota_logger = SOTALogger(
    config=FLAGS.logging, variant=variant, env_name=FLAGS.env
  )
  setup_logger(
    variant=variant,
    exp_id=sota_logger.experiment_id,
    seed=FLAGS.seed,
    base_log_dir=FLAGS.logging.output_dir,
    include_exp_prefix_sub_dir=False
  )

  set_random_seed(FLAGS.seed)

  if FLAGS.dataset == 'd4rl':
    eval_sampler = TrajSampler(
      gym.make(FLAGS.env).unwrapped, FLAGS.max_traj_length
    )
    if FLAGS.online:
      train_sampler = StepSampler(
        gym.make(FLAGS.env).unwrapped, FLAGS.max_traj_length
      )
    else:
      # Build dataset and sampler
      dataset = get_d4rl_dataset(
        eval_sampler.env,
        algo_cfg.nstep,
        algo_cfg.discount,
        top_n=FLAGS.top_n,
        sort_by=FLAGS.sort_by,
      )

      dataset['rewards'] = (
        dataset['rewards'] * FLAGS.reward_scale + FLAGS.reward_bias
      )
      dataset['actions'] = np.clip(
        dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action
      )

      if FLAGS.normalize_reward:
        normalize(dataset)

      # Quantity that drives the sampling probabilities below.
      if 'reward' in FLAGS.sampler:
        dist = dataset['rewards']
      elif 'return' in FLAGS.sampler:
        dist = dataset['returns']
      else:
        dist = dataset['returns']  # default

      dataset = Dataset(dataset)
      if FLAGS.sampler == 'online':
        # NOTE(review): `SlidingWindowSampler` is only mentioned in a
        # commented-out import, and this branch never defines `probs`, which
        # is read right after the chain -- this path looks unusable as is.
        sampler = SlidingWindowSampler(
          dataset.size(),
          FLAGS.n_epochs * FLAGS.n_train_step_per_epoch,
          FLAGS.window_size,
          FLAGS.batch_size,
        )
      elif 'm-balanced' in FLAGS.sampler:  # magnitude balanced
        probs = (dist - dist.min()) / (dist.max() - dist.min())
      elif 'c-balanced' in FLAGS.sampler:  # class-balanced
        # NOTE(review): `+1e3` keeps the ratio strictly below 1 but also
        # compresses it; a small epsilon (e.g. 1e-3) may have been intended.
        dist = (dist - dist.min()) / (dist.max() - dist.min()+1e3)  # values between [0,1)
        labels = np.floor((dist * 50)).astype(np.int32)  # bin index, at most 49
        class_count = {label:len(np.nonzero(labels==label)[0]) for label in np.unique(labels)}
        print(class_count)
        # NOTE(review): np.any receives a dict view rather than an array of
        # counts, so this likely does not check individual counts -- confirm.
        assert np.any(class_count.values())
        # Weight each entry by the inverse square root of its bin population.
        probs = np.array([1./np.sqrt(class_count[label]) for label in labels])
      else:
        probs = (dist - dist.min()) / (dist.max() - dist.min())
      probs = probs / probs.sum()

      if FLAGS.bc_eval:
        # weight loading module (filename changed)
        weight_list = []
        for seed in range(1, FLAGS.weight_num + 1):
          try:
            wp = FLAGS.weight_path % seed
          except (TypeError, ValueError):
            # The path has no usable placeholder for the seed.
            wp = FLAGS.weight_path  # load the specified weight
          # NOTE(review): allow_pickle=True unpickles the file's contents;
          # only load weight files from trusted sources.
          eval_res = np.load(wp, allow_pickle=True).item()
          try:
            # The 'eval_steps' lookup is kept: a missing key selects the
            # legacy layout below, exactly as before.
            num_iter, _ = eval_res['iter'], eval_res['eval_steps']
            assert FLAGS.iter <= num_iter
            t = eval_res[FLAGS.iter]
          except Exception:
            # Fallback layout: the weights live under eval_res[1000000]['adv'].
            assert FLAGS.iter == 1
            t = eval_res[1000000]['adv']
          weight_list.append(t)
          print(f'Loading weights from {wp} at {FLAGS.iter}th rebalanced behavior policy')
        if FLAGS.weight_ensemble == 'mean':
          weight = np.stack(weight_list, axis=0).mean(axis=0)
        elif FLAGS.weight_ensemble == 'median':
          weight = np.median(np.stack(weight_list, axis=0), axis=0)
        else:
          raise NotImplementedError
        assert dataset.size() == weight.shape[0]
        size = dataset.size()
        if FLAGS.weight_func == 'linear':
          weight = weight - weight.min()
          probs = weight / weight.sum()
          # keep mean, scale std
          if FLAGS.std:
            scale = FLAGS.std / (probs.std() * size)
            probs = scale*(probs - 1/size) + 1/size
            if FLAGS.eps:  # after scaling, a prob may be negative
              probs = np.maximum(probs, FLAGS.eps/size)
            if FLAGS.eps_max:  # after scaling, a prob may be too large
              probs = np.minimum(probs, FLAGS.eps_max/size)
          probs = probs/probs.sum()  # norm to 1 again
        elif FLAGS.weight_func == 'exp':
          weight = weight / np.abs(weight).mean()
          weight = np.exp(FLAGS.exp_lambd * weight)
          probs = weight / weight.sum()
        else:
          raise NotImplementedError

      if FLAGS.sampler in ['uniform', 'random']:
        sampler = RandSampler(dataset.size(), FLAGS.batch_size)
      else:
        sampler = PrefetchBalancedSampler(
          probs.squeeze(),
          dataset.size(),
          FLAGS.batch_size,
          FLAGS.n_train_step_per_epoch
        )
      if FLAGS.reweight:
        # Scale by the dataset size so that uniform probabilities map to 1.
        dataset.set_weight(probs*dataset.size())
      dataset.set_sampler(sampler)

  elif FLAGS.dataset == 'rl_unplugged':
    # NOTE(review): `RLUPDataset` and `DM2Gym` are only mentioned in a
    # commented-out import, so this branch cannot run until they are imported.
    path = Path(__file__).absolute().parent.parent / 'data'
    dataset = RLUPDataset(
      FLAGS.rl_unplugged_task_class,
      FLAGS.env,
      str(path),
      batch_size=FLAGS.batch_size,
      action_clipping=FLAGS.clip_action,
    )

    env = DM2Gym(dataset.env)
    eval_sampler = TrajSampler(
      env, max_traj_length=FLAGS.max_traj_length
    )
    if FLAGS.online:
      train_sampler = StepSampler(
        env, max_traj_length=FLAGS.max_traj_length
      )
  else:
    raise NotImplementedError

  if FLAGS.online:
    replay_buffer = ReplayBuffer(FLAGS.replay_buffer_size)

  observation_dim = eval_sampler.env.observation_space.shape[0]
  action_dim = eval_sampler.env.action_space.shape[0]

  if not FLAGS.use_resnet:
    use_layer_norm = FLAGS.use_layer_norm
    embedding_dim = observation_dim
    encoder = IdentityEncoder(1, embedding_dim, nn.relu)
    if FLAGS.decoupled_q:
      embedding_dim = FLAGS.embedding_dim
      encoder = FullyConnectedNetwork(
        embedding_dim, FLAGS.encoder_arch, FLAGS.orthogonal_init,
        use_layer_norm, 'relu'
      )
    if FLAGS.algo == 'MPO':
      # MPO always uses layer norm here; this also applies to the Q/V nets
      # built below, which read the same local.
      use_layer_norm = True
      policy = ClipGaussianPolicy(
        observation_dim, embedding_dim, action_dim, FLAGS.policy_arch, FLAGS.orthogonal_init,
        FLAGS.policy_log_std_multiplier, FLAGS.policy_log_std_offset, use_layer_norm,
        FLAGS.activation,
      )
    else:
      policy = TanhGaussianPolicy(
        observation_dim, embedding_dim, action_dim, FLAGS.policy_arch, FLAGS.orthogonal_init,
        FLAGS.policy_log_std_multiplier, FLAGS.policy_log_std_offset,
      )

    qf = FullyConnectedQFunction(
      observation_dim,
      action_dim,
      FLAGS.qf_arch,
      FLAGS.orthogonal_init,
      use_layer_norm,
      FLAGS.activation,
    )
    if FLAGS.decoupled_q:
      qf = DecoupledQFunction(
        FLAGS.embedding_dim,
        observation_dim,
        action_dim,
        FLAGS.qf_arch,
        FLAGS.orthogonal_init,
        use_layer_norm,
        FLAGS.activation,
      )
    vf = None
    if FLAGS.algo == 'IQL':
      vf = FullyConnectedVFunction(
        observation_dim,
        FLAGS.qf_arch,
        FLAGS.orthogonal_init,
        use_layer_norm,
        FLAGS.activation,
      )
  else:
    hidden_dim = algo_cfg.res_hidden_size
    encoder_blocks = algo_cfg.encoder_blocks
    encoder = ResEncoder(
      encoder_blocks,
      hidden_dim,
      nn.elu,
    )
    # One if/elif/else chain: previously the flag name was misspelled
    # (`FLAGS.also`) and the MPO policy was overwritten by the final `else`.
    if FLAGS.algo == 'MPO':
      policy = ResClipGaussianPolicy(
        observation_dim,
        action_dim,
        f"{algo_cfg.res_hidden_size}",
        FLAGS.orthogonal_init,
        FLAGS.policy_log_std_multiplier,
        FLAGS.policy_log_std_offset,
        algo_cfg.head_blocks,
      )
    elif FLAGS.algo == 'IQL':
      # No ResNet value function is built in this file.
      raise NotImplementedError
    else:
      policy = ResTanhGaussianPolicy(
        observation_dim,
        action_dim,
        f"{algo_cfg.res_hidden_size}",
        FLAGS.orthogonal_init,
        FLAGS.policy_log_std_multiplier,
        FLAGS.policy_log_std_offset,
        algo_cfg.head_blocks,
      )
    qf = ResQFunction(
      observation_dim,
      action_dim,
      f"{algo_cfg.res_hidden_size}",
      FLAGS.orthogonal_init,
      algo_cfg.head_blocks,
    )

  # A non-negative target entropy is replaced by minus the action dimension.
  if algo_cfg.target_entropy >= 0.0:
    algo_cfg.target_entropy = -np.prod(
      eval_sampler.env.action_space.shape
    ).item()

  if FLAGS.algo == 'IQL':
    agent = off_algo(algo_cfg, encoder, policy, qf, vf)
  else:
    agent = off_algo(algo_cfg, encoder, policy, qf, decoupled_q=FLAGS.decoupled_q)
  sampler_policy = SamplerPolicyEncoder(agent, agent.train_params)

  viskit_metrics = {}
  returns = []
  normalized_returns = []
  for epoch in range(FLAGS.n_epochs):
    metrics = {'epoch': epoch}

    if FLAGS.online:
      with Timer() as rollout_timer:
        train_sampler.sample(
          sampler_policy.update_params(agent.train_params),
          FLAGS.n_env_steps_per_epoch, deterministic=False,
          replay_buffer=replay_buffer
        )
        metrics['env_steps'] = replay_buffer.total_steps
        metrics['epoch'] = epoch

    with Timer() as train_timer:
      for _ in tqdm.tqdm(range(FLAGS.n_train_step_per_epoch)):
        if FLAGS.online:
          batch = batch_to_jax(replay_buffer.sample(FLAGS.batch_size))
        else:
          batch = batch_to_jax(dataset.sample())
        train_metrics = agent.train(
          batch,
          FLAGS.reweight_eval,
          FLAGS.reweight_improve,
          FLAGS.reweight_constraint,
          bc=epoch < FLAGS.bc_epochs,
        )
        metrics.update(prefix_metrics(train_metrics, 'agent'))

    with Timer() as eval_timer:
      if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0:
        trajs = eval_sampler.sample(
          sampler_policy.update_params(agent.train_params),
          FLAGS.eval_n_trajs,
          deterministic=True
        )

        metrics['average_return'] = np.mean(
          [np.sum(t['rewards']) for t in trajs]
        )
        metrics['average_traj_length'] = np.mean(
          [len(t['rewards']) for t in trajs]
        )
        metrics['average_normalizd_return'] = np.mean(
          [
            eval_sampler.env.get_normalized_score(np.sum(t['rewards']))
            for t in trajs
          ]
        )

        # get the average of the final 10 evaluations as performance
        if FLAGS.eval_smooth:
          returns.append(metrics['average_return'])
          normalized_returns.append(metrics['average_normalizd_return'])
          # A [-10:] slice already copes with fewer than 10 entries.
          metrics['average_10_return'] = np.mean(returns[-10:])
          metrics['average_10_normalizd_return'] = np.mean(
            normalized_returns[-10:]
          )

        if FLAGS.save_model:
          save_data = {'agent': agent, 'variant': variant, 'epoch': epoch}
          sota_logger.save_pickle(save_data, 'model.pkl')

    if FLAGS.online:
      metrics['rollout_time'] = rollout_timer()
    metrics['train_time'] = train_timer()
    metrics['eval_time'] = eval_timer()
    # NOTE(review): rollout time is not part of epoch_time -- confirm intent.
    metrics['epoch_time'] = train_timer() + eval_timer()
    sota_logger.log(metrics)
    # viskit_metrics accumulates across epochs, so keys that were not
    # refreshed this epoch keep their last value in the tabular log.
    viskit_metrics.update(metrics)
    logger.record_dict(viskit_metrics)
    logger.dump_tabular(with_prefix=False, with_timestamp=False)

  if FLAGS.save_model:
    save_data = {'agent': agent, 'variant': variant, 'epoch': epoch}
    sota_logger.save_pickle(save_data, 'model.pkl')


# Script entry point: only when this file is executed directly (not imported),
# hand `main` to absl's application runner.
if __name__ == '__main__':
    absl.app.run(main)
